# Copyright 2025 The OpenXLA Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Utilities for detecting checksum mismatches in buffer debug logs.

The log is generated by running with
--xla_gpu_experimental_enable_checksum_tracing_on_thunks.
"""

import collections
import dataclasses
import itertools
from typing import Callable, Iterable, NewType, Optional, Self, TypeVar

from xla.backends.gpu.runtime import buffer_debug_log_pb2
from xla.backends.gpu.runtime import thunk_pb2


# Distinct integer ID types, so that a type checker can tell the different kinds
# of IDs and values apart. At runtime these are plain ints.

# Identifies one execution of an HLO module; used as the key of the log protos
# passed to ChecksumMismatchReport.from_protos.
ModuleExecutionId = NewType("ModuleExecutionId", int)
# Identifies a thunk; used as the key for thunk metadata and mismatch reports.
ThunkId = NewType("ThunkId", int)
# Identifies a buffer in a BufferChecksums mapping.
BufferIdx = NewType("BufferIdx", int)
# A checksum value recorded for a buffer.
Checksum = NewType("Checksum", int)


@dataclasses.dataclass(frozen=True)
class BufferChecksums:
  """A buffer => checksum mapping that can be used as a dict key.

  The hash is computed from the mapping's items in sorted order, so two
  instances holding the same items hash identically regardless of the order in
  which the items were inserted.
  """

  # Checksum of each buffer, keyed by buffer index.
  checksums: dict[BufferIdx, Checksum]

  def __hash__(self) -> int:
    # Sorting makes the hashed tuple independent of dict insertion order.
    ordered_items = sorted(self.checksums.items())
    return hash(tuple(ordered_items))


@dataclasses.dataclass(frozen=True)
class ThunkMetadata:
  """Thunk metadata, read from ThunkMetadataListProto.

  Stored in a separate type to enable type checking.
  """

  # The ID of the thunk this metadata describes (`thunk_info.thunk_id` in the
  # proto).
  thunk_id: ThunkId
  # The kind of the thunk (`thunk_kind` in the proto).
  thunk_kind: str
  # The profile annotation of the thunk (`thunk_info.profile_annotation` in the
  # proto).
  profile_annotation: Optional[str]


@dataclasses.dataclass(frozen=True)
class ThunkExecution:
  """The details of a single execution of a thunk."""

  # An ID of the HLO module execution that produced this thunk execution.
  module_execution_id: int
  # An ID of the thunk execution within a HLO module execution. If a thunk
  # executes in a loop, there will be multiple entries with the same thunk_id
  # but different execution IDs.
  thunk_execution_id: int
  # The ID of the thunk that was executed. Details about the thunk can be found
  # in ThunkMetadata.
  thunk_id: ThunkId
  # Checksums of buffers with defined contents before thunk execution.
  # These are used to identify repeats of the same computation that are expected
  # to produce the same results.
  input_checksums: BufferChecksums
  # Checksums of buffers with defined contents after thunk execution.
  # These are the values we want to verify are consistent across executions.
  output_checksums: BufferChecksums


@dataclasses.dataclass(frozen=True)
class ChecksumMismatchReport:
  """A report of output checksum mismatches found across thunk executions."""

  # Metadata of the thunks, keyed by thunk ID.
  thunk_metadata: dict[ThunkId, ThunkMetadata]
  # Thunks for which different executions produced different results. The value
  # is an input checksums => output checksum sets dict containing the info
  # about inconsistent outputs, and the checksums of inputs that caused them.
  mismatches: dict[
      ThunkId, dict[BufferChecksums, dict[BufferIdx, set[Checksum]]]
  ]

  @classmethod
  def from_protos(
      cls,
      log_protos: dict[
          ModuleExecutionId, buffer_debug_log_pb2.BufferDebugLogProto
      ],
      metadata_proto: thunk_pb2.ThunkMetadataListProto,
  ) -> Self:
    """Builds a ChecksumMismatchReport out of parsed protobufs.

    Args:
      log_protos: A dict of BufferDebugLogProto keyed by module execution ID.
      metadata_proto: A ThunkMetadataListProto.

    Returns:
      A report holding the parsed thunk metadata and the mismatches detected
      among the thunk executions recorded in all of the given logs.

    Preconditions:
      - All log protos must refer to the same HLO module.
      - metadata proto must describe the same HLO module as the log protos or be
        an empty proto.
    """
    thunk_metadata = _parse_metadata(metadata_proto)

    # Lazily parse each (module execution ID, log) pair and flatten the
    # resulting lists into a single stream of thunk executions.
    executions_per_module = itertools.starmap(_parse_log, log_protos.items())
    all_executions = itertools.chain.from_iterable(executions_per_module)

    return cls(
        thunk_metadata=thunk_metadata,
        mismatches=_find_inconsistent_thunks(all_executions),
    )


K = TypeVar("K")
T = TypeVar("T")


def group_by(
    values: Iterable[T], key_getter: Callable[[T], K]
) -> dict[K, list[T]]:
  """Groups the items of an iterable by the result of a key function.

  Args:
    values: The items to group. Iterated exactly once.
    key_getter: Called once per item; returns the (hashable) key of the group
      the item belongs to.

  Returns:
    A plain dict mapping each key to the list of items that produced it. Keys
    appear in order of first occurrence and the items of each group keep their
    original relative order. Looking up a key that no item produced raises
    KeyError.
  """
  grouped: collections.defaultdict[K, list[T]] = collections.defaultdict(list)
  for item in values:
    grouped[key_getter(item)].append(item)
  # Hand back a plain dict rather than the defaultdict: otherwise a caller that
  # looks up an absent key would silently get (and insert) an empty group.
  return dict(grouped)


def _parse_metadata(
    metadata_proto: thunk_pb2.ThunkMetadataListProto,
) -> dict[ThunkId, ThunkMetadata]:
  """Converts a ThunkMetadataListProto into ThunkMetadata keyed by thunk ID.

  Args:
    metadata_proto: The proto whose `thunk_metadata` entries are converted.

  Returns:
    A dict with one ThunkMetadata per thunk ID found in the proto. If a thunk
    ID occurs more than once, the last entry wins.
  """
  parsed_entries = (
      ThunkMetadata(
          thunk_id=ThunkId(item.thunk_info.thunk_id),
          thunk_kind=item.thunk_kind,
          profile_annotation=item.thunk_info.profile_annotation,
      )
      for item in metadata_proto.thunk_metadata
  )
  return {parsed.thunk_id: parsed for parsed in parsed_entries}


def _parse_log(
    module_execution: int,
    log_proto: buffer_debug_log_pb2.BufferDebugLogProto,
) -> list[ThunkExecution]:
  """Converts the entries of a BufferDebugLogProto into ThunkExecutions.

  Log entries sharing the same (thunk ID, execution ID) pair are folded into a
  single ThunkExecution. Their checksums are split into inputs and outputs
  according to each entry's `is_input_buffer` flag.

  Args:
    module_execution: The ID of the module execution the log belongs to. It is
      stored on every returned ThunkExecution.
    log_proto: The log to parse.

  Returns:
    One ThunkExecution per distinct (thunk ID, execution ID) pair in the log.
  """
  grouped_entries = group_by(
      log_proto.entries, lambda e: (e.thunk_id, e.execution_id)
  )

  parsed: list[ThunkExecution] = []
  for (thunk_id, execution_id), entries in grouped_entries.items():
    inputs = {}
    outputs = {}
    for entry in entries:
      # Route every checksum to the mapping matching the entry's direction.
      destination = inputs if entry.is_input_buffer else outputs
      destination[entry.buffer_idx] = entry.checksum
    parsed.append(
        ThunkExecution(
            module_execution_id=module_execution,
            thunk_execution_id=execution_id,
            thunk_id=thunk_id,
            input_checksums=BufferChecksums(inputs),
            output_checksums=BufferChecksums(outputs),
        )
    )
  return parsed


def _find_inconsistent_output_checksums(
    executions: list[ThunkExecution],
) -> dict[BufferIdx, set[Checksum]]:
  """Reports output buffers whose checksum varied between the given executions.

  Args:
    executions: Executions of the same thunk on the same input arguments, which
      are therefore expected to produce identical outputs.

  Returns:
    A dict holding only the output buffers for which more than one distinct
    checksum was observed, mapped to the set of all checksums observed for that
    buffer. Empty if all outputs were consistent.
  """
  observed: dict[BufferIdx, set[Checksum]] = collections.defaultdict(set)
  for run in executions:
    outputs = run.output_checksums.checksums
    for idx, value in outputs.items():
      observed[idx].add(value)

  inconsistent = {}
  for idx, values in observed.items():
    # A single distinct checksum means every execution agreed on this buffer.
    if len(values) <= 1:
      continue
    inconsistent[idx] = values
  return inconsistent


def _find_inconsistent_thunks(
    executions: Iterable[ThunkExecution],
) -> dict[ThunkId, dict[BufferChecksums, dict[BufferIdx, set[Checksum]]]]:
  """Detects thunks whose outputs differ between executions on equal inputs.

  Executions are first bucketed by thunk ID and then by their input checksums.
  Within each bucket all executions are expected to yield the same output
  checksums.

  Args:
    executions: An arbitrary collection of thunk executions.

  Returns:
    A dict containing only the thunks for which an inconsistency was found.

    Each value is a dict keyed by the set of input checksums, with values
    identifying the output buffers with inconsistent checksums, along with the
    set of observed checksums for each.
  """
  runs_by_thunk = group_by(executions, lambda run: run.thunk_id)

  report: dict[
      ThunkId, dict[BufferChecksums, dict[BufferIdx, set[Checksum]]]
  ] = {}
  for thunk_id, thunk_runs in runs_by_thunk.items():
    runs_by_inputs = group_by(thunk_runs, lambda run: run.input_checksums)

    # Keep only the input sets for which at least one output buffer disagreed.
    inconsistent_by_inputs = {
        inputs: inconsistent
        for inputs, identical_runs in runs_by_inputs.items()
        if (inconsistent := _find_inconsistent_output_checksums(identical_runs))
    }

    if inconsistent_by_inputs:
      report[thunk_id] = inconsistent_by_inputs

  return report
